import os
import pandas as pd
import numpy as np
import torch
import os
from torch.utils.data import TensorDataset, DataLoader



class Recommender_Dataset(object):
    """Load a treatment/outcome CSV, split it randomly, and build evaluation frames.

    The CSV is read from ``<rootPath>/dataset/<dataset>/<dataset>.csv``.
    Missing values are forward-filled. Rows are then randomly split into
    ``self.train`` / ``self.val`` / ``self.test`` according to
    ``config['splits']`` (a ``"train/val/test"`` string of fractions such as
    ``"0.6/0.2/0.2"``; the test split receives whatever rows remain).
    Sampling uses NumPy's global RNG, so seed it (``np.random.seed``) for
    reproducible splits.

    For the validation and test splits, two extra frames are built in which
    every unit is assigned treatment 1 (``'yf'`` taken from column ``mu1``)
    or treatment 0 (``'yf'`` taken from column ``mu0``).

    ``config`` is mutated in place with derived statistics: ``n_units``,
    ``n_unit``, ``treated_ratio``, ``control_ratio``, ``x_n_covariate``,
    ``v_n_covariate``, ``start_order`` and ``end_order``.
    """

    def __init__(self, config):
        self.config = config
        self.dataset = config['dataset']
        self.t_field = config['t_field']
        self.x_field = config['x_field']
        self.v_field = config['v_field']
        self.uid_field = config['uid_field']
        self.vid_field = config['vid_field']
        self.yf_field = config['yf_field']
        self.ycf_field = config['ycf_field']

        # Columns whose values become 'yf' in the all-treated / all-control
        # evaluation frames built by _add_weight.
        self.t_mu = 'mu1'
        self.c_mu = 'mu0'

        self._load_feat(os.path.join(self.config['rootPath'], 'dataset/{}/{}.csv'.format(
            self.dataset, self.dataset)))

        self._add_weight()

        n_units = len(self)
        n_treated, n_control = self._treatment_counts()
        # NOTE: an empty CSV raises ZeroDivisionError here, as before.
        self.config['treated_ratio'] = n_treated / n_units
        self.config['control_ratio'] = n_control / n_units
        self.config['n_unit'] = n_units
        self.config['x_n_covariate'] = len(self.feat[self.x_field].columns)
        self.config['v_n_covariate'] = len(self.feat[self.v_field].columns)
        self.config['start_order'] = 0
        self.config['end_order'] = 100

    def _add_weight(self):
        """Build all-treated and all-control frames for the val and test splits.

        Sets ``val_treated``, ``val_control``, ``test_treated`` and
        ``test_control``. Each has columns ``'treatment'`` and ``'yf'``,
        followed by the ``x_field + v_field`` covariates.
        """
        self.val_treated = self._assign_treatment(self.val, 1, self.t_mu)
        self.val_control = self._assign_treatment(self.val, 0, self.c_mu)
        self.test_treated = self._assign_treatment(self.test, 1, self.t_mu)
        self.test_control = self._assign_treatment(self.test, 0, self.c_mu)

    def _assign_treatment(self, split, treatment, outcome_col):
        """Return *split* with every unit forced to *treatment* and 'yf' from *outcome_col*.

        ``split`` must have a fresh 0..n-1 index (as produced by
        ``reset_index(drop=True)`` in ``_load_feat``). ``pd.concat(axis=1)``
        aligns on the index, so any other index would misalign the rows.
        """
        header = pd.DataFrame({
            'treatment': np.full(len(split), treatment, dtype=float),
            'yf': split[outcome_col].values,
        })
        return pd.concat([header, split[self.x_field + self.v_field]], axis=1)

    def _treatment_counts(self):
        """Return ``(n_treated, n_control)`` as Python ints from ``feat['treatment']``."""
        treatment = self.feat['treatment']
        return int((treatment == 1).sum()), int((treatment == 0).sum())

    @staticmethod
    def _split_count(fraction, n_units):
        """Return ``floor(fraction * n_units)``, tolerant of float representation error.

        Plain ``int(0.57 * 100)`` is 56 because the product is
        56.99999999999999. Rounding to 6 decimals first restores 57 while
        still flooring genuinely fractional products (``0.6 * 7 -> 4``).
        """
        return int(round(fraction * n_units, 6))

    @staticmethod
    def _percent(count, total):
        """Return ``count / total`` as a percentage rounded to a whole number (as float).

        Rounding *after* scaling avoids artefacts such as
        ``100 * round(0.57, 2) == 56.99999999999999``.
        """
        return round(100 * count / total, 0)

    def _load_feat(self, feat_path):
        """Read the CSV at *feat_path*, forward-fill gaps, and split rows into train/val/test."""
        df = pd.read_csv(feat_path)
        # fillna(method="ffill") is deprecated since pandas 2.1; ffill() is equivalent.
        self.feat = df.ffill()

        self.n_units = len(self.feat)
        self.config['n_units'] = self.n_units

        splits = self.config['splits'].strip().split('/')
        # The third fraction is parsed (so a malformed string still fails here)
        # but unused: the test split takes every row not drawn for train/val.
        n_train, n_val, _ = float(splits[0]), float(splits[1]), float(splits[2])
        feat_index = set(range(0, self.n_units))

        # The set-based population ordering is kept exactly as-is so seeded
        # runs reproduce the same splits as before.
        train_index = list(np.random.choice(
            list(feat_index), self._split_count(n_train, self.n_units), replace=False))
        val_index = list(np.random.choice(
            list(feat_index - set(train_index)), self._split_count(n_val, self.n_units),
            replace=False))
        test_index = list(feat_index - set(train_index) - set(val_index))

        self.train = self.feat.iloc[train_index].reset_index(drop=True)
        self.val = self.feat.iloc[val_index].reset_index(drop=True)
        self.test = self.feat.iloc[test_index].reset_index(drop=True)

    def __getitem__(self, index):
        """Delegate to ``self.feat[index]`` (pandas ``[]``: labels / label lists select columns)."""
        df = self.feat[index]
        return df

    def __len__(self):
        """Return the number of rows in the full (unsplit) table."""
        return len(self.feat)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        """Return a multi-line human-readable summary of the dataset and its splits."""
        n_units = len(self)
        n_treated, n_control = self._treatment_counts()
        info = ['[{}]'.format(self.dataset)]
        info.append('The number of units: {} ({} treated {}%, {} control {}%)'.format(
            n_units,
            n_treated, self._percent(n_treated, n_units),
            n_control, self._percent(n_control, n_units)))
        info.append('The number of covariates: {}'.format(
            self.config['x_n_covariate'] + self.config['v_n_covariate']))
        info.append('The number of treatments: {}'.format(
            len(np.unique(self.feat['treatment'].values))))
        info.append('The splits ratios {}: {}/{}/{}'.format(
            self.config['splits'], len(self.train), len(self.val), len(self.test)))
        return '\n'.join(info)


class AbstractDataLoader(object):
    """Base class for loaders: stores ``config`` / ``dataset`` and copies field names.

    Each of the seven ``*_field`` entries of ``config`` is copied onto the
    instance under the same attribute name. A missing entry raises ``KeyError``.
    """

    def __init__(self, config, dataset):
        self.config, self.dataset = config, dataset

        # Looked up in this fixed order, so the first missing key is the one reported.
        field_keys = ('t_field', 'x_field', 'v_field', 'uid_field',
                      'vid_field', 'yf_field', 'ycf_field')
        for key in field_keys:
            setattr(self, key, config[key])


class RecommenderDataLoader(AbstractDataLoader):
    """Wrap a tabular split into a batched torch ``DataLoader``.

    ``dataset`` must support ``dataset[columns].values`` (e.g. a pandas
    DataFrame — confirm against callers) and must contain the config's
    ``x_field`` / ``v_field`` / ``t_field`` columns plus a ``'yf'`` column.
    Iterating yields 4-tuples ``(x, t, y, v)``:

    - ``x`` and ``v`` are float32 covariates.
    - ``t`` is an int32 tensor reshaped to one column.
    - ``y`` is a float32 tensor reshaped to one column.
    """

    def __init__(self, config, dataset, batch_size=1024, shuffle=True):
        super().__init__(config, dataset)

        frame = self.dataset
        covariates_x = frame[self.x_field].values
        covariates_v = frame[self.v_field].values
        treatment = frame[self.t_field].values
        outcome = frame['yf'].values

        x_tensor = torch.from_numpy(covariates_x).float()
        t_tensor = torch.from_numpy(treatment.reshape(-1, 1)).int()
        y_tensor = torch.from_numpy(outcome.reshape(-1, 1)).float()
        v_tensor = torch.from_numpy(covariates_v).float()

        self.ds = TensorDataset(x_tensor, t_tensor, y_tensor, v_tensor)
        self.dl = DataLoader(self.ds, batch_size=batch_size, shuffle=shuffle)
        # Shape of the raw x covariate array, exposed through get_X_size().
        self.size = covariates_x.shape

    def get_X_size(self):
        """Return the shape of the x-covariate array this loader was built from."""
        return self.size

    def __len__(self):
        """Return the number of batches per epoch."""
        return len(self.dl)

    def __iter__(self):
        """Yield batches from the underlying ``DataLoader``."""
        yield from self.dl


